{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第5章 决策树\n",
    "\n",
    "参考：https://github.com/fengdu78/lihang-code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1．分类决策树模型是表示基于特征对实例进行分类的树形结构。决策树可以转换成一个**if-then**规则的集合，也可以看作是定义在特征空间划分上的类的条件概率分布。\n",
    "\n",
    "2．决策树学习旨在构建一个与训练数据拟合很好，并且复杂度小的决策树。因为从可能的决策树中直接选取最优决策树是NP完全问题。现实中采用启发式方法学习次优的决策树。\n",
    "\n",
    "决策树学习算法包括3部分：特征选择、树的生成和树的剪枝。常用的算法有ID3、\n",
    "C4.5和CART。\n",
    "\n",
    "3．特征选择的目的在于选取对训练数据能够分类的特征。特征选择的关键是其准则。常用的准则如下：\n",
    "\n",
    "（1）样本集合$D$对特征$A$的信息增益（ID3）\n",
    "\n",
    "\n",
    "$$g(D, A)=H(D)-H(D|A)$$\n",
    "\n",
    "$$H(D)=-\\sum_{k=1}^{K} \\frac{\\left|C_{k}\\right|}{|D|} \\log _{2} \\frac{\\left|C_{k}\\right|}{|D|}$$\n",
    "\n",
    "$$H(D | A)=\\sum_{i=1}^{n} \\frac{\\left|D_{i}\\right|}{|D|} H\\left(D_{i}\\right)$$\n",
    "\n",
    "其中，$H(D)$是数据集$D$的熵，$H(D_i)$是数据集$D_i$的熵，$H(D|A)$是数据集$D$对特征$A$的条件熵。\t$D_i$是$D$中特征$A$取第$i$个值的样本子集，$C_k$是$D$中属于第$k$类的样本子集。$n$是特征$A$取 值的个数，$K$是类的个数。\n",
    "\n",
    "（2）样本集合$D$对特征$A$的信息增益比（C4.5）\n",
    "\n",
    "\n",
    "$$g_{R}(D, A)=\\frac{g(D, A)}{H(D)}$$\n",
    "\n",
    "\n",
    "其中，$g(D,A)$是信息增益，$H(D)$是数据集$D$的熵。\n",
    "\n",
    "（3）样本集合$D$的基尼指数（CART）\n",
    "\n",
    "$$\\operatorname{Gini}(D)=1-\\sum_{k=1}^{K}\\left(\\frac{\\left|C_{k}\\right|}{|D|}\\right)^{2}$$\n",
    "\n",
    "特征$A$条件下集合$D$的基尼指数：\n",
    "\n",
    " $$\\operatorname{Gini}(D, A)=\\frac{\\left|D_{1}\\right|}{|D|} \\operatorname{Gini}\\left(D_{1}\\right)+\\frac{\\left|D_{2}\\right|}{|D|} \\operatorname{Gini}\\left(D_{2}\\right)$$\n",
    " \n",
    "4．决策树的生成。通常使用信息增益最大、信息增益比最大或基尼指数最小作为特征选择的准则。决策树的生成往往通过计算信息增益或其他指标，从根结点开始，递归地产生决策树。这相当于用信息增益或其他准则不断地选取局部最优的特征，或将训练集分割为能够基本正确分类的子集。\n",
    "\n",
    "5．决策树的剪枝。由于生成的决策树存在过拟合问题，需要对它进行剪枝，以简化学到的决策树。决策树的剪枝，往往从已生成的树上剪掉一些叶结点或叶结点以上的子树，并将其父结点或根结点作为新的叶结点，从而简化生成的决策树。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-17T12:35:54.514006Z",
     "start_time": "2019-10-17T12:35:52.534622Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import Counter\n",
    "import math\n",
    "from math import log\n",
    "import pprint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 书上题目5.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T03:49:06.439359Z",
     "start_time": "2019-10-19T03:49:06.429405Z"
    }
   },
   "outputs": [],
   "source": [
    "def create_data():\n",
    "    datasets = [['青年', '否', '否', '一般', '否'],\n",
    "               ['青年', '否', '否', '好', '否'],\n",
    "               ['青年', '是', '否', '好', '是'],\n",
    "               ['青年', '是', '是', '一般', '是'],\n",
    "               ['青年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '好', '否'],\n",
    "               ['中年', '是', '是', '好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '好', '是'],\n",
    "               ['老年', '是', '否', '好', '是'],\n",
    "               ['老年', '是', '否', '非常好', '是'],\n",
    "               ['老年', '否', '否', '一般', '否'],\n",
    "               ]\n",
    "    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况', '类别']\n",
    "    # 返回数据集和每个维度的名称\n",
    "    return datasets, feature_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T03:49:07.344764Z",
     "start_time": "2019-10-19T03:49:07.341772Z"
    }
   },
   "outputs": [],
   "source": [
    "datasets, feature_names = create_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T03:49:09.931805Z",
     "start_time": "2019-10-19T03:49:09.917850Z"
    }
   },
   "outputs": [],
   "source": [
    "train_data = pd.DataFrame(datasets, columns=feature_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T03:49:11.812141Z",
     "start_time": "2019-10-19T03:49:11.792194Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>年龄</th>\n",
       "      <th>有工作</th>\n",
       "      <th>有自己的房子</th>\n",
       "      <th>信贷情况</th>\n",
       "      <th>类别</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>青年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>青年</td>\n",
       "      <td>是</td>\n",
       "      <td>是</td>\n",
       "      <td>一般</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>中年</td>\n",
       "      <td>是</td>\n",
       "      <td>是</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>老年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>老年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    年龄 有工作 有自己的房子 信贷情况 类别\n",
       "0   青年   否      否   一般  否\n",
       "1   青年   否      否    好  否\n",
       "2   青年   是      否    好  是\n",
       "3   青年   是      是   一般  是\n",
       "4   青年   否      否   一般  否\n",
       "5   中年   否      否   一般  否\n",
       "6   中年   否      否    好  否\n",
       "7   中年   是      是    好  是\n",
       "8   中年   否      是  非常好  是\n",
       "9   中年   否      是  非常好  是\n",
       "10  老年   否      是  非常好  是\n",
       "11  老年   否      是    好  是\n",
       "12  老年   是      否    好  是\n",
       "13  老年   是      否  非常好  是\n",
       "14  老年   否      否   一般  否"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-17T13:15:16.280602Z",
     "start_time": "2019-10-17T13:15:16.268632Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "# 定义树节点类, 分类树\n",
    "class Node:\n",
    "    def __init__(self, root=True, label=None, feature_name=None, feature=None):\n",
    "        # root 标记是否为单结点树（即叶子结点）\n",
    "        self.root = root\n",
    "        # label 为当前节点的类别标记\n",
    "        self.label = label\n",
    "        # feature_name 为当前节点的特征名\n",
    "        self.feature_name = feature_name\n",
    "        # 子树\n",
    "        self.tree = {}\n",
    "        self.result = {\n",
    "            'lable': self.label,\n",
    "            'feature': self.feature_name,\n",
    "            'tree': self.tree\n",
    "        }\n",
    "\n",
    "    def __repr__(self):\n",
    "        \"\"\"\n",
    "        重写 __repr__ 用于显示对象，也可以 使用 __str__\n",
    "        dt = Node()\n",
    "        print(dt) 就可以显示自己定义的信息了\n",
    "        这里  self.result 中包含 self.tree，所以会递归显示所有子树信息\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        return '{}'.format(self.result)\n",
    "\n",
    "    def add_node(self, val, node):\n",
    "        \"\"\"\n",
    "        增加树节点\n",
    "        \"\"\"\n",
    "        self.tree[val] = node\n",
    "\n",
    "    def predict(self, features):\n",
    "        if self.root is True:\n",
    "            return self.label\n",
    "        return self.tree[features[self.feature]].predict(features)\n",
    "\n",
    "\n",
    "class DTree:\n",
    "    def __init__(self, epsilon=0.1):\n",
    "        self.epsilon = epsilon\n",
    "        self._tree = {}\n",
    "\n",
    "    # 熵\n",
    "    @staticmethod\n",
    "    def calc_ent(datasets):\n",
    "        data_length = len(datasets)\n",
    "        label_count = {}\n",
    "        for i in range(data_length):\n",
    "            label = datasets[i][-1]\n",
    "            if label not in label_count:\n",
    "                label_count[label] = 0\n",
    "            label_count[label] += 1\n",
    "        ent = -sum([(p / data_length) * log(p / data_length, 2)\n",
    "                    for p in label_count.values()])\n",
    "        return ent\n",
    "\n",
    "    # 经验条件熵 H(D|A)\n",
    "    def cond_ent(self, datasets, index=0):\n",
    "        \"\"\"\n",
    "        :param index: 当前特征列 A 在 datasets 中的索引\n",
    "        \"\"\"\n",
    "        data_length = len(datasets)\n",
    "        # 其中每个元素存储当前特征值时的所有样本\n",
    "        feature_sets = {}\n",
    "        for i in range(data_length):\n",
    "            feature = datasets[i][index]\n",
    "            if feature not in feature_sets:\n",
    "                feature_sets[feature] = []\n",
    "            # 当前样本 datasets[i] 加入列表中\n",
    "            feature_sets[feature].append(datasets[i])\n",
    "        # H(D|A)\n",
    "        cond_ent = sum([(len(p) / data_length) * self.calc_ent(p) for p in feature_sets.values()])\n",
    "        return cond_ent\n",
    "\n",
    "    # 信息增益\n",
    "    @staticmethod\n",
    "    def info_gain(ent, cond_ent):\n",
    "        return ent - cond_ent\n",
    "\n",
    "    def info_gain_train(self, datasets):\n",
    "        count = len(datasets[0]) - 1\n",
    "        ent = self.calc_ent(datasets)\n",
    "        best_feature = []\n",
    "        # 对所有特征计算信息增益\n",
    "        for c in range(count):\n",
    "            c_info_gain = self.info_gain(ent, self.cond_ent(datasets, index=c))\n",
    "            best_feature.append((c, c_info_gain))\n",
    "\n",
    "        # 选择信息增益最大的特征作为根节点\n",
    "        best_ = max(best_feature, key=lambda x: x[-1])\n",
    "        return best_\n",
    "\n",
    "    def train(self, train_data):\n",
    "        \"\"\"\n",
    "        :param train_data: 训练数据， pandas 格式，最后一列为 label\n",
    "        :return: 决策树\n",
    "        \"\"\"\n",
    "        # y_train: 最后一列 label ， features : 特征名\n",
    "        _, y_train, features = train_data.iloc[:, :-1], train_data.iloc[:, -1], train_data.columns[: -1]\n",
    "\n",
    "        # 1. 若 D 中实例属于同一类 Ck， 则 T 为单节点树，并将类 Ck 作为结点的类标记，返回T\n",
    "        if len(y_train.value_counts()) == 1:\n",
    "            return Node(root=True, label=y_train.iloc[0])\n",
    "        # 2. 若特征集 A 为空，则T为单节点树，将 D 中实例数最大的类 Ck 作为该结点的类标记，返回 T\n",
    "        if len(features) == 0:\n",
    "            # 对 Series 排序，去索引第一的作为类标\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "        # 3. 计算 A 中各特征对 D 的最大信息增益，选择 Ag\n",
    "        max_feature, max_info_gain = self.info_gain_train(np.array(train_data))\n",
    "        max_feature_name = features[max_feature]\n",
    "\n",
    "        # 4. Ag 的信息增益小于阈值 eta，则置T为单节点树，并将D中是实例数最大的类Ck作为该节点的类标记，返回T\n",
    "        if max_info_gain < self.epsilon:\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "\n",
    "        # 5. 构建 Ag 子集，对每一个 Ag = ai 将D分割为若干个子集（相当于多个子树），子集中实例数最大的类别作为当前节点\n",
    "        node_tree = Node(root=False, feature_name=max_feature_name)\n",
    "        # 获取特征 Ag 的所有取值\n",
    "        feature_list = train_data[max_feature_name].unique()\n",
    "        for f in feature_list:\n",
    "            # 获取子集，删掉当前特征\n",
    "            sub_train_df = train_data.loc[train_data[max_feature_name] == f].drop([max_feature_name], axis=1)\n",
    "\n",
    "            # 6. 递归调用，生成树\n",
    "            sub_tree = self.train(sub_train_df)\n",
    "            node_tree.add_node(f, sub_tree)\n",
    "\n",
    "        return node_tree\n",
    "\n",
    "    def fit(self, train_data, label):\n",
    "        train_data['label'] = label\n",
    "        self._tree = self.train(train_data)\n",
    "        return self._tree\n",
    "\n",
    "    def predict(self, X_test):\n",
    "        return self._tree.predict(X_test)\n",
    "\n",
    "    def score(self, test_data, test_label):\n",
    "        \"\"\"\n",
    "        :param test_data: 测试数据\n",
    "        :param test_label: 测试数据标签\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        predict_label = self.predict(test_data)\n",
    "\n",
    "        error_cnt = 0\n",
    "        for k in range(len(predict_label)):\n",
    "            # 记录误分类数\n",
    "            if predict_label[k] != test_label[k]:\n",
    "                error_cnt += 1\n",
    "        # 正确率 = 1 - 错误分类样本数 / 总样本数\n",
    "        print('accuracy: ', 1 - error_cnt / len(predict_label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-17T13:15:17.732207Z",
     "start_time": "2019-10-17T13:15:17.727219Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "特征(年龄) - info_gain - 0.083\n",
      "特征(有工作) - info_gain - 0.324\n",
      "特征(有自己的房子) - info_gain - 0.420\n",
      "特征(信贷情况) - info_gain - 0.363\n",
      "(2, 0.4199730940219749)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'特征(有自己的房子)的信息增益最大，选择为根节点特征'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info_gain_train(np.array(datasets), labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 利用ID3算法生成决策树，例5.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T09:15:28.244946Z",
     "start_time": "2019-10-19T09:15:28.213035Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "# 定义树节点类, 分类树\n",
    "class Node:\n",
    "    def __init__(self, root=True, label=None, feature_name=None):\n",
    "        # root 标记是否为单结点树（即叶子结点）\n",
    "        self.root = root\n",
    "        # label 为当前节点的类别标记\n",
    "        self.label = label\n",
    "        # feature_name 为当前节点的特征名\n",
    "        self.feature_name = feature_name\n",
    "\n",
    "        # 子树，字典形式，（特征取值：子树）\n",
    "        self.tree = {}\n",
    "        self.result = {\n",
    "            'lable': self.label,\n",
    "            'feature': self.feature_name,\n",
    "            'tree': self.tree\n",
    "        }\n",
    "\n",
    "    def __repr__(self):\n",
    "        return '{}'.format(self.result)\n",
    "\n",
    "    def add_node(self, val, node):\n",
    "        \"\"\"\n",
    "        :param val: 特征的取值\n",
    "        :param node:\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        \"\"\"\n",
    "        增加树节点\n",
    "        \"\"\"\n",
    "        self.tree[val] = node\n",
    "\n",
    "    def predict(self, x_sample):\n",
    "        \"\"\"\n",
    "        :param x_sample: 一行数据 Series 格式，能通过特征名索引特征值\n",
    "        :return: 当前样本的预测值\n",
    "        \"\"\"\n",
    "        if self.root is True:\n",
    "            return self.label\n",
    "        # x_sample[self.feature_name] 当前特征下样本实例的取值\n",
    "        return self.tree[x_sample[self.feature_name]].predict(x_sample)\n",
    "\n",
    "\n",
    "class DTree:\n",
    "    def __init__(self, epsilon=0.1):\n",
    "        self.epsilon = epsilon\n",
    "        self._tree = {}\n",
    "\n",
    "    # 熵\n",
    "    @staticmethod\n",
    "    def calc_ent(datasets):\n",
    "        data_length = len(datasets)\n",
    "        label_count = {}\n",
    "        for i in range(data_length):\n",
    "            label = datasets[i][-1]\n",
    "            if label not in label_count:\n",
    "                label_count[label] = 0\n",
    "            label_count[label] += 1\n",
    "        ent = -sum([(p / data_length) * log(p / data_length, 2)\n",
    "                    for p in label_count.values()])\n",
    "        return ent\n",
    "\n",
    "    # 经验条件熵 H(D|A)\n",
    "    def cond_ent(self, datasets, index=0):\n",
    "        \"\"\"\n",
    "        :param index: 当前特征列 A 在 datasets 中的索引\n",
    "        \"\"\"\n",
    "        data_length = len(datasets)\n",
    "        # 其中每个元素存储当前特征值时的所有样本\n",
    "        feature_sets = {}\n",
    "        for i in range(data_length):\n",
    "            feature = datasets[i][index]\n",
    "            if feature not in feature_sets:\n",
    "                feature_sets[feature] = []\n",
    "            # 当前样本 datasets[i] 加入列表中\n",
    "            feature_sets[feature].append(datasets[i])\n",
    "        # H(D|A)\n",
    "        cond_ent = sum([(len(p) / data_length) * self.calc_ent(p) for p in feature_sets.values()])\n",
    "        return cond_ent\n",
    "\n",
    "    # 信息增益\n",
    "    @staticmethod\n",
    "    def info_gain(ent, cond_ent):\n",
    "        return ent - cond_ent\n",
    "\n",
    "    def info_gain_train(self, datasets):\n",
    "        count = len(datasets[0]) - 1\n",
    "        ent = self.calc_ent(datasets)\n",
    "        best_feature = []\n",
    "        # 对所有特征计算信息增益\n",
    "        for c in range(count):\n",
    "            c_info_gain = self.info_gain(ent, self.cond_ent(datasets, index=c))\n",
    "            best_feature.append((c, c_info_gain))\n",
    "\n",
    "        # 选择信息增益最大的特征作为根节点\n",
    "        best_ = max(best_feature, key=lambda x: x[-1])\n",
    "        return best_\n",
    "\n",
    "    def train(self, train_data):\n",
    "        \"\"\"\n",
    "        :param train_data: 训练数据， pandas 格式，最后一列为 label\n",
    "        :return: 决策树\n",
    "        \"\"\"\n",
    "        # y_train: 最后一列 label ， features : 特征名\n",
    "        _, y_train, features = train_data.iloc[:, :-1], train_data.iloc[:, -1], train_data.columns[: -1]\n",
    "\n",
    "        # 1. 若 D 中实例属于同一类 Ck， 则 T 为单节点树，并将类 Ck 作为结点的类标记，返回T\n",
    "        if len(y_train.value_counts()) == 1:\n",
    "            return Node(root=True, label=y_train.iloc[0])\n",
    "        # 2. 若特征集 A 为空，则T为单节点树，将 D 中实例数最大的类 Ck 作为该结点的类标记，返回 T\n",
    "        if len(features) == 0:\n",
    "            # 对 Series 排序，去索引第一的作为类标\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "        # 3. 计算 A 中各特征对 D 的最大信息增益，选择 Ag\n",
    "        max_feature, max_info_gain = self.info_gain_train(np.array(train_data))\n",
    "        max_feature_name = features[max_feature]\n",
    "\n",
    "        # 4. Ag 的信息增益小于阈值 eta，则置T为单节点树，并将D中是实例数最大的类Ck作为该节点的类标记，返回T\n",
    "        if max_info_gain < self.epsilon:\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "\n",
    "        # 5. 构建 Ag 子集，对每一个 Ag = ai 将D分割为若干个子集（相当于多个子树），子集中实例数最大的类别作为当前节点\n",
    "        node_tree = Node(root=False, feature_name=max_feature_name)\n",
    "        # 获取特征 Ag 的所有取值\n",
    "        feature_list = train_data[max_feature_name].unique()\n",
    "        for f in feature_list:\n",
    "            # 获取子集，删掉当前特征\n",
    "            sub_train_df = train_data.loc[train_data[max_feature_name] == f].drop([max_feature_name], axis=1)\n",
    "\n",
    "            # 6. 递归调用，生成树\n",
    "            sub_tree = self.train(sub_train_df)\n",
    "            node_tree.add_node(f, sub_tree)\n",
    "\n",
    "        return node_tree\n",
    "\n",
    "    def fit(self, train_data):\n",
    "        self._tree = self.train(train_data)\n",
    "        return self._tree\n",
    "\n",
    "    def predict_one_sample(self, x_vec):\n",
    "        \"\"\"\n",
    "        预测一条样本\n",
    "        :param x_vec: 样本向量\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        return self._tree.predict(x_vec)\n",
    "\n",
    "    def predict(self, X_test):\n",
    "        \"\"\"\n",
    "        预测结果\n",
    "        :param X_test: DataFrame 格式\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        result = []\n",
    "        for i in range(X_test.shape[0]):\n",
    "            result.append(self.predict_one_sample(X_test.iloc[i]))\n",
    "        return result\n",
    "\n",
    "    def score(self, test_data, test_label):\n",
    "        \"\"\"\n",
    "        :param test_data: 测试数据\n",
    "        :param test_label: 测试数据标签\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        predict_label = self.predict(test_data)\n",
    "\n",
    "        error_cnt = 0\n",
    "        for k in range(len(predict_label)):\n",
    "            # 记录误分类数\n",
    "            if predict_label[k] != test_label[k]:\n",
    "                error_cnt += 1\n",
    "        # 正确率 = 1 - 错误分类样本数 / 总样本数\n",
    "        print('accuracy: ', 1 - error_cnt / len(predict_label))\n",
    "\n",
    "def create_data():\n",
    "    datasets = [['青年', '否', '否', '一般', '否'],\n",
    "               ['青年', '否', '否', '好', '否'],\n",
    "               ['青年', '是', '否', '好', '是'],\n",
    "               ['青年', '是', '是', '一般', '是'],\n",
    "               ['青年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '好', '否'],\n",
    "               ['中年', '是', '是', '好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '好', '是'],\n",
    "               ['老年', '是', '否', '好', '是'],\n",
    "               ['老年', '是', '否', '非常好', '是'],\n",
    "               ['老年', '否', '否', '一般', '否'],\n",
    "               ]\n",
    "    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况', '类别']\n",
    "    # 返回数据集和每个维度的名称\n",
    "    return datasets, feature_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T09:16:16.962621Z",
     "start_time": "2019-10-19T09:16:16.947661Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'lable': None, 'feature': '有自己的房子', 'tree': {'否': {'lable': None, 'feature': '有工作', 'tree': {'否': {'lable': '否', 'feature': None, 'tree': {}}, '是': {'lable': '是', 'feature': None, 'tree': {}}}}, '是': {'lable': '是', 'feature': None, 'tree': {}}}}\n",
      "accuracy:  1.0\n"
     ]
    }
   ],
   "source": [
    "datasets, labels = create_data()\n",
    "data_df = pd.DataFrame(datasets, columns=labels)\n",
    "dt = DTree()\n",
    "tree = dt.fit(data_df)\n",
    "print(tree)\n",
    "X_train = data_df.iloc[:, :-1]\n",
    "y_train = data_df.iloc[:, -1]\n",
    "dt.score(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T09:17:06.294759Z",
     "start_time": "2019-10-19T09:17:06.289772Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "年龄        中年\n",
      "有工作        否\n",
      "有自己的房子     否\n",
      "信贷情况      一般\n",
      "Name: 5, dtype: object\n",
      "label:否\n",
      "predict:否\n"
     ]
    }
   ],
   "source": [
    "print(X_train.iloc[5])\n",
    "print(\"label:\" + y_train[5])\n",
    "print(\"predict:\" + dt.predict_one_sample(X_train.iloc[5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T09:24:07.501278Z",
     "start_time": "2019-10-19T09:24:07.498287Z"
    }
   },
   "outputs": [],
   "source": [
    "a = np.array([[1,2],[3,4]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-10-19T09:24:21.933454Z",
     "start_time": "2019-10-19T09:24:21.928438Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   0  1\n",
      "0  1  2\n",
      "1  3  4\n"
     ]
    }
   ],
   "source": [
    "print(pd.DataFrame(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
